Connected to chess (Python 3.11.5)
"""Initial Exploration 05"""
# %% [markdown]
# # Initial Explorations 05
# Attempting to follow initial explorations from Neel Nanda here. This time, I'm doing it with the multigame data like Neel did.
'Initial Exploration 05'
# Imports
import einops
from IPython.display import display
import pysvelte
import torch
import transformer_lens
from transformer_lens import HookedTransformer, ActivationCache
from transformers import PreTrainedTokenizerFast
from austin_plotly import imshow, line
from mech_interp.visualizations import (
get_legal_tokens,
plot_board_log_probs,
get_game_prefix_up_to_token_idx,
map_token_to_move_index,
plot_valid_moves,
plot_single_board,
)
from mech_interp.mappings import logitId2token, logitId2square
from mech_interp.utils import uci_to_board
from mech_interp.chess_dataset import ChessDataImporter
from mech_interp.fixTL import make_official
torch.set_grad_enabled(False)
MODEL_NAME = make_official()
# pysvelte.Hello(name='austin')
Setup¶
# Load the fast tokenizer and the hooked (TransformerLens) chess-UCI model.
tokenizer: PreTrainedTokenizerFast = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
model = HookedTransformer.from_pretrained(MODEL_NAME, tokenizer)
Loaded pretrained model AustinD/gpt2-chess-uci-hooked into HookedTransformer
NUM_SEQS = 100  # number of games to analyze
GAME_INDEX = 0  # game used for single-game deep dives
MOVE_LIMIT = 126  # token positions iterated per game (127 tokens incl. <s>)
DATA_SET_SOURCE = "dev"
# Pre-tokenized games as an int tensor of token ids.
board_seqs_int = torch.load(
    f"./chess_data/tensorized/embeded_games_127_{DATA_SET_SOURCE}.pt"
)[:NUM_SEQS]
# Indices into the raw dataset for each tensorized game.
game_indices = torch.load(
    f"./chess_data/tensorized/embeded_games_127_{DATA_SET_SOURCE}_indices.pt"
)[:NUM_SEQS]
cdi = ChessDataImporter(DATA_SET_SOURCE)
# Raw UCI move strings and their token lists for the selected games.
board_seqs_string = [cdi.games[i] for i in game_indices]
board_seqs_tokens = [
    tokenizer.tokenize(cdi.games[i], add_special_tokens=True) for i in game_indices
]
# Character offsets of each token within its game string, used below to
# map tokens back to moves.
board_seqs_offset_mappings = tokenizer.batch_encode_plus(
    board_seqs_string, return_offsets_mapping=True
)["offset_mapping"]
encoded = tokenizer.batch_encode_plus(board_seqs_string, return_offsets_mapping=True)
# For each game, compute the set of legal tokens at every position.
# This is the ground truth for the "valid move" analyses below.
valid_moves_list = []
for i in range(NUM_SEQS):
    tokens = board_seqs_tokens[i]
    uci_moves = board_seqs_string[i]
    offset_mapping = board_seqs_offset_mappings[i]
    # Replay the game to get a board state after every move (forced and
    # silent so odd continuations don't abort the replay).
    board_stack = uci_to_board(
        uci_moves=uci_moves,
        force=True,
        fail_silent=True,
        verbose=False,
        as_board_stack=True,
    )
    # Legal next tokens at each position (the final token has no successor).
    temp_valid_moves = [
        get_legal_tokens(pos, uci_moves, tokens, offset_mapping, board_stack)
        for pos in range(len(tokens) - 1)
    ]
    valid_moves_list.append(temp_valid_moves)
print(list(zip(list(range(len(valid_moves_list[0]))), valid_moves_list[0]))[:1])
print([[len(valid_moves_list[i][j]) for j in range(-10, 0)] for i in range(NUM_SEQS)])
[(0, ['a2', 'b1', 'b2', 'c2', 'd2', 'e2', 'f2', 'g1', 'g2', 'h2'])] [[6, 6, 8, 16, 6, 11, 8, 1, 6, 2], [8, 6, 2, 1, 8, 6, 7, 11, 8, 9], [1, 18, 1, 4, 1, 13, 6, 11, 3, 1], [1, 3, 6, 7, 6, 10, 6, 7, 6, 10], [8, 7, 2, 1, 7, 7, 1, 1, 6, 17], [9, 5, 1, 3, 8, 7, 7, 4, 9, 2], [8, 16, 9, 7, 2, 1, 8, 5, 8, 12], [9, 5, 8, 5, 9, 12, 8, 7, 3, 7], [6, 10, 6, 4, 6, 9, 6, 1, 6, 10], [2, 7, 2, 2, 12, 6, 8, 4, 12, 8], [2, 3, 6, 11, 9, 7, 7, 16, 11, 10], [9, 2, 6, 10, 9, 1, 6, 5, 8, 6], [5, 10, 6, 4, 5, 1, 6, 1, 5, 1], [6, 21, 2, 2, 6, 1, 5, 12, 1, 1], [9, 9, 1, 5, 9, 12, 7, 4, 8, 14], [9, 5, 6, 14, 9, 1, 6, 2, 9, 1], [8, 16, 8, 13, 7, 4, 8, 5, 7, 8], [10, 7, 1, 4, 9, 15, 1, 1, 9, 15], [10, 8, 7, 11, 9, 20, 7, 10, 9, 12], [9, 5, 9, 2, 9, 5, 10, 12, 9, 9], [3, 1, 8, 7, 8, 12, 2, 1, 8, 7], [7, 11, 1, 1, 1, 12, 1, 2, 6, 21], [1, 10, 2, 1, 7, 7, 6, 7, 7, 6], [8, 8, 8, 20, 1, 15, 1, 1, 7, 5], [8, 6, 1, 10, 1, 13, 7, 5, 1, 3], [1, 5, 9, 4, 2, 2, 7, 8, 6, 2], [9, 9, 8, 1, 9, 4, 8, 16, 9, 10], [12, 6, 9, 3, 12, 7, 9, 5, 12, 6], [8, 8, 2, 4, 1, 1, 7, 9, 8, 7], [8, 2, 8, 10, 8, 3, 9, 11, 8, 5], [10, 13, 7, 7, 10, 5, 8, 2, 10, 5], [10, 1, 10, 3, 10, 1, 10, 2, 9, 1], [9, 6, 7, 2, 4, 3, 6, 12, 8, 7], [1, 3, 7, 4, 8, 10, 3, 7, 7, 7], [10, 15, 8, 5, 10, 1, 8, 6, 9, 10], [12, 3, 13, 1, 11, 6, 13, 4, 11, 1], [7, 7, 1, 4, 7, 4, 4, 1, 7, 4], [8, 2, 8, 16, 8, 6, 8, 2, 8, 2], [1, 1, 6, 11, 6, 8, 1, 12, 1, 2], [10, 6, 10, 10, 10, 7, 10, 5, 2, 13], [9, 12, 9, 4, 10, 17, 1, 1, 9, 1], [11, 14, 8, 10, 10, 2, 9, 9, 11, 7], [5, 14, 7, 15, 4, 1, 6, 9, 4, 9], [7, 2, 8, 11, 6, 3, 6, 8, 6, 11], [3, 2, 8, 1, 9, 8, 8, 4, 9, 3], [3, 5, 8, 1, 8, 8, 8, 2, 1, 2], [7, 1, 7, 10, 6, 10, 7, 9, 6, 7], [10, 1, 10, 9, 9, 2, 8, 15, 9, 2], [10, 5, 1, 12, 9, 3, 8, 7, 9, 5], [7, 17, 7, 7, 7, 8, 7, 2, 7, 16], [6, 1, 6, 13, 5, 13, 6, 11, 2, 1], [10, 6, 1, 3, 10, 7, 8, 1, 11, 5], [4, 4, 6, 8, 5, 7, 1, 2, 6, 9], [8, 3, 6, 14, 8, 7, 7, 15, 8, 9], [1, 2, 6, 8, 1, 2, 6, 12, 5, 8], [10, 2, 9, 9, 1, 8, 8, 11, 9, 7], [11, 6, 1, 5, 10, 8, 8, 
10, 10, 13], [6, 17, 1, 2, 6, 1, 2, 3, 5, 20], [11, 11, 9, 7, 11, 17, 1, 1, 10, 10], [8, 17, 9, 4, 8, 7, 8, 15, 6, 5], [10, 12, 6, 2, 1, 11, 1, 2, 9, 17], [8, 6, 6, 1, 8, 9, 6, 5, 8, 2], [9, 8, 7, 2, 8, 6, 7, 8, 8, 8], [7, 11, 8, 1, 7, 2, 8, 2, 7, 10], [8, 11, 5, 7, 8, 17, 1, 2, 8, 12], [9, 8, 9, 13, 10, 18, 9, 8, 10, 19], [10, 11, 9, 3, 10, 15, 1, 2, 10, 6], [10, 6, 10, 11, 11, 7, 10, 12, 2, 5], [10, 9, 8, 4, 10, 9, 8, 4, 2, 11], [7, 12, 10, 8, 1, 12, 9, 4, 6, 10], [1, 4, 6, 7, 6, 1, 1, 2, 1, 4], [7, 6, 7, 2, 2, 7, 6, 9, 6, 5], [3, 1, 7, 7, 10, 5, 6, 10, 10, 7], [10, 10, 11, 5, 2, 8, 2, 2, 9, 12], [7, 2, 7, 6, 7, 16, 6, 7, 7, 13], [7, 1, 2, 1, 7, 7, 7, 5, 7, 7], [11, 6, 2, 8, 10, 6, 8, 11, 10, 11], [8, 12, 9, 7, 8, 9, 9, 8, 8, 12], [2, 3, 9, 2, 8, 1, 9, 10, 8, 5], [9, 2, 1, 2, 1, 15, 7, 15, 9, 9], [9, 12, 9, 3, 9, 14, 8, 6, 9, 6], [9, 8, 7, 8, 9, 8, 1, 5, 1, 6], [7, 13, 6, 21, 3, 1, 6, 12, 1, 14], [10, 11, 10, 2, 9, 1, 9, 1, 8, 8], [9, 12, 8, 13, 9, 12, 2, 2, 9, 12], [9, 2, 8, 2, 7, 9, 6, 1, 7, 4], [7, 14, 6, 5, 7, 7, 1, 1, 1, 11], [1, 5, 6, 9, 5, 13, 6, 6, 5, 8], [11, 9, 1, 2, 11, 8, 2, 1, 11, 8], [8, 9, 2, 5, 8, 5, 6, 8, 8, 8], [1, 8, 8, 4, 9, 9, 8, 4, 10, 10], [8, 11, 6, 13, 8, 6, 6, 19, 8, 14], [11, 5, 10, 11, 11, 8, 9, 7, 11, 8], [9, 5, 7, 5, 1, 1, 7, 5, 3, 9], [8, 16, 7, 7, 8, 14, 7, 6, 8, 1], [11, 11, 9, 1, 11, 7, 9, 7, 11, 12], [6, 7, 9, 9, 3, 1, 9, 18, 2, 2], [1, 1, 6, 19, 1, 3, 6, 7, 8, 12], [9, 9, 10, 2, 1, 2, 1, 5, 7, 11], [10, 8, 1, 4, 2, 1, 9, 1, 2, 6]]
# Run a single game through the model, caching all activations.
logits, cache = model.run_with_cache(board_seqs_int[GAME_INDEX])
uci_moves = board_seqs_string[GAME_INDEX]
print(f"uci_moves: {uci_moves}")
print(f"uci_moves tokenized: {model.to_str_tokens(uci_moves)}")
print(f"logits.shape: {logits.shape}")
uci_moves: d2d4 d7d5 c2c4 d5c4 e2e4 e7e5 g1f3 f8b4 c1d2 b4d2 b1d2 b8c6 f1c4 d8f6 e1g1 e5d4 b2b4 a7a6 a2a4 g8e7 b4b5 c6a5 c4d3 c7c5 d1c2 b7b6 e4e5 f6h6 d2e4 e8g8 e4g5 e7g6 a1b1 a6b5 b1b5 g6e5 d3h7 g8h8 f3e5 h6g5 f2f4 g5h6 h7e4 a8b8 f1f3 h8g8 f4f5 f8e8 e5g4 h6g5 f3g3 c8b7 g4f2 g5f4 b5b6 b7e4 b6b8 e8b8 f2e4 f4e5 g3g4 f7f6 h2h4 uci_moves tokenized: ['<s>', 'd2', 'd4', 'd7', 'd5', 'c2', 'c4', 'd5', 'c4', 'e2', 'e4', 'e7', 'e5', 'g1', 'f3', 'f8', 'b4', 'c1', 'd2', 'b4', 'd2', 'b1', 'd2', 'b8', 'c6', 'f1', 'c4', 'd8', 'f6', 'e1', 'g1', 'e5', 'd4', 'b2', 'b4', 'a7', 'a6', 'a2', 'a4', 'g8', 'e7', 'b4', 'b5', 'c6', 'a5', 'c4', 'd3', 'c7', 'c5', 'd1', 'c2', 'b7', 'b6', 'e4', 'e5', 'f6', 'h6', 'd2', 'e4', 'e8', 'g8', 'e4', 'g5', 'e7', 'g6', 'a1', 'b1', 'a6', 'b5', 'b1', 'b5', 'g6', 'e5', 'd3', 'h7', 'g8', 'h8', 'f3', 'e5', 'h6', 'g5', 'f2', 'f4', 'g5', 'h6', 'h7', 'e4', 'a8', 'b8', 'f1', 'f3', 'h8', 'g8', 'f4', 'f5', 'f8', 'e8', 'e5', 'g4', 'h6', 'g5', 'f3', 'g3', 'c8', 'b7', 'g4', 'f2', 'g5', 'f4', 'b5', 'b6', 'b7', 'e4', 'b6', 'b8', 'e8', 'b8', 'f2', 'e4', 'f4', 'e5', 'g3', 'g4', 'f7', 'f6', 'h2', 'h4'] logits.shape: torch.Size([1, 127, 77])
# Decompose the final residual stream into per-layer attention outputs
# (mode="attn", embeddings excluded), with LayerNorm applied.
resid_decomp, component_labels = cache.decompose_resid(
    -1, apply_ln=True, return_labels=True, incl_embeds=False, mode="attn"
)
print(f"W_U.shape: {model.W_U.shape}")  # [d_model, d_vocab]
# [comp, batch, pos, d_model]
print(f"resid_decomp.shape: {resid_decomp.shape}")
# Logit lens: project each component's contribution straight to vocab space.
decomp_logits = resid_decomp @ model.W_U
decomp_logits = decomp_logits.squeeze(1)  # remove batch dim
print(f"decomp_logits.shape: {decomp_logits.shape} #[component, pos, d_vocab]")
print(f"decomp_logits logit id: {torch.argmax(decomp_logits[0,-1])}")
W_U.shape: torch.Size([768, 77]) resid_decomp.shape: torch.Size([12, 1, 127, 768]) decomp_logits.shape: torch.Size([12, 127, 77]) #[component, pos, d_vocab] decomp_logits logit id: 64
# Layer-wise Logits
# Replay the chosen game to get a board state after every move.
board_states = uci_to_board(
    uci_moves,
    force=True,
    as_board_stack=True,
)
# encoded = tokenizer.encode_plus(uci_moves, return_offsets_mapping=True)
print(len(encoded["offset_mapping"]))
100
Logit Lens by Pos¶
Idea¶
Broadly, a GPT operates as follows:
- Project input tokens from vocab space into embedding space
- Modify the embedded vector multiple times (this is done by each component of each layer)
- Project the resulting embedded vector back to vocab space
For a GPT to work well, the embedded vector must encode relevant information. The Logit Lens technique (LL) projects the embedded (residual) vector into vocab space after being modified by module/component $k$.
# For the first few positions, plot each component's logit-lens view of
# the board, alongside the board state / last move played.
for comp_idx in range(10):  # np.random.randint(1,100,[4]):
    move_idx = map_token_to_move_index(
        uci_moves, comp_idx, offset_mapping=encoded["offset_mapping"][GAME_INDEX]
    )
    # Only render the board log-prob plot on odd positions past the opening.
    if comp_idx > 2 and comp_idx % 2 == 1:
        prefix = get_game_prefix_up_to_token_idx(
            comp_idx, encoded["offset_mapping"][GAME_INDEX], uci_moves
        )
        print(prefix)
        plot_board_log_probs(prefix.strip(), tokenizer, logits[0, :comp_idx])
    # print(game_states[move_idx-1].unicode())
    if len(board_states[move_idx - 1].move_stack) > 0:
        print(board_states[move_idx - 1].peek())
    imshow(
        decomp_logits[:, comp_idx],
        y=component_labels,
        x=[f"{t} ({i})" for t, i in zip(logitId2token, list(range(0, 77)))],
        xaxis="Token",
        yaxis="Component",
        title=f"Logit Lens for board state after position/token {comp_idx}",
    )
d2d4
d2d4
d2d4
d2d4 d7d5
d7d5
d7d5
d2d4 d7d5 c2c4
c2c4
c2c4
d2d4 d7d5 c2c4 d5c4
d5c4
# Batched runs: all NUM_SEQS games ("big") and the first five ("small").
big_logits, big_cache = model.run_with_cache(board_seqs_int[:NUM_SEQS])
small_logits, small_cache = model.run_with_cache(board_seqs_string[:5])
big_log_probs = big_logits.log_softmax(-1)
small_log_probs = small_logits.log_softmax(-1)
# Per-component logit lens across all positions of the chosen game.
# Fixes: title typo ("Posistion"); the einops pattern previously named the
# axes "tok pos" even though decomp_logits[comp_idx] is [pos, vocab].
for comp_idx in [0]:  # range(resid_decomp.shape[0]): #for each component
    imshow(
        # Transpose [pos, vocab] -> [vocab, pos] so positions run along x.
        einops.rearrange(decomp_logits[comp_idx], "pos vocab -> vocab pos"),
        x=[
            f'{i}: "{logitId2token[id]}" ({id})'
            for i, id in enumerate(encoded["input_ids"][GAME_INDEX])
        ],
        yaxis="Token Id",
        xaxis="Input Position",
        title=f"Component Logit Lens by Position: {component_labels[comp_idx]}",
        labels=dict(x="Position", y="Token"),
        # color_continuous_scale='RdBu'
    )
# Same logit lens, but columns are the game's actual input tokens, so the
# plot shows each position's logit for the tokens that were played.
for comp_idx in [0]:  # range(resid_decomp.shape[0]):
    imshow(
        decomp_logits[comp_idx][:, encoded["input_ids"][GAME_INDEX]],
        # width=1000, height=1000,
        x=[
            f'{i}: "{logitId2token[id]}" ({id})'
            for i, id in enumerate(encoded["input_ids"][GAME_INDEX])
        ],
        xaxis="Token (sorted by pos, with possible duplicates)",
        yaxis="Position",
        title=f"Component Logit Lens by Pos (sorted): {component_labels[comp_idx]}",
        labels=dict(y="Position", x="Token"),
    )
# Board-level log-prob visualization for one game; also export to HTML.
plot_board_log_probs(
    board_seqs_string[GAME_INDEX],
    tokenizer,
    small_logits[GAME_INDEX],
)
TEXT_INDEX = 0
fig = plot_board_log_probs(
    board_seqs_string[TEXT_INDEX], tokenizer, small_logits[TEXT_INDEX], return_fig=True
)
fig.write_html("Game0.html", include_plotlyjs="cdn")
fig.show()
# Stack every layer/head attention pattern, select one game, and flatten
# to [dest, src, layer*head]. Fixes: the first printed annotation omitted
# the batch dim (the tensor is 5-D, and dim 1 is indexed by TEXT_INDEX
# below), and "dext" was a typo for "dest".
big_patterns = big_cache.stack_activation("pattern")
print(big_patterns.shape, "# [n_layer, batch, n_head, dest, src]")
patterns = big_patterns[:, TEXT_INDEX, :]
patterns = einops.rearrange(patterns, "layer head dest src -> dest src (layer head)")
print(patterns.shape, "# [dest, src, (layer head)]")
torch.Size([12, 100, 12, 127, 127]) # [n_layer, n_head, dest, src] torch.Size([127, 127, 144]) # [dext, src, (layer head)]
# Disabled pysvelte attention visualization; flip the flag to re-enable.
if PYSVELT_WORKING := False:
    for layer in range(1):
        # Select this layer's heads and a window of tokens to display.
        LH_display_range = slice(12 * layer, 12 * (layer + 1))
        TOK_display_range = slice(1, 30)
        head_patterns = pysvelte.AttentionMulti(
            attention=patterns[TOK_display_range, TOK_display_range, LH_display_range],
            tokens=[
                f"{t} ({i+TOK_display_range.start})"
                for i, t in enumerate(tokenizer.tokenize(uci_moves)[TOK_display_range])
            ],
            head_labels=[
                f"L{l}H{h}"
                for l in range(model.cfg.n_layers)
                for h in range(model.cfg.n_heads)
            ][LH_display_range],
        )
        head_patterns.show()
Visualizing Valid moves¶
Here's a "Gotcha!": tokenizer.tokenize simply tokenizes the string, but encoding adds the special token
if VALIDATE_INDICES := True:  # Verifying index funcs work as expected
    from mech_interp.visualizations import (
        map_token_to_move_offsets,
        determine_move_phase,
    )
    # For the first ten tokens, print the move phase, the move index, and
    # the character span of the containing move within the UCI string.
    for comp_idx in range(10):
        tok = tokenizer.tokenize(uci_moves, add_special_tokens=True)[comp_idx]
        start, end = map_token_to_move_offsets(
            comp_idx, offset_mapping=encoded["offset_mapping"][GAME_INDEX]
        )
        phase = determine_move_phase(uci_moves[start:end], tok)
        mv_index = map_token_to_move_index(
            uci_moves,
            comp_idx,
            offset_mapping=encoded["offset_mapping"][GAME_INDEX],
            zero_index=True,
        )
        print(
            f'ph: {f"{phase:<7}"}',
            f"tok: {tok}",
            f"mvidx: {mv_index}->{uci_moves.split()[mv_index]}",
            f"rng={(start,end)}",
            f"uci: {uci_moves[:start]} >>{uci_moves[start:end]}<< {uci_moves[end:end+5]}",
            sep=" ",
        )
ph: special tok: <s> mvidx: 0->d2d4 rng=(0, 4) uci: >>d2d4<< d7d5 ph: from tok: d2 mvidx: 0->d2d4 rng=(0, 4) uci: >>d2d4<< d7d5 ph: to tok: d4 mvidx: 0->d2d4 rng=(0, 4) uci: >>d2d4<< d7d5 ph: from tok: d7 mvidx: 1->d7d5 rng=(5, 9) uci: d2d4 >>d7d5<< c2c4 ph: to tok: d5 mvidx: 1->d7d5 rng=(5, 9) uci: d2d4 >>d7d5<< c2c4 ph: from tok: c2 mvidx: 2->c2c4 rng=(10, 14) uci: d2d4 d7d5 >>c2c4<< d5c4 ph: to tok: c4 mvidx: 2->c2c4 rng=(10, 14) uci: d2d4 d7d5 >>c2c4<< d5c4 ph: from tok: d5 mvidx: 3->d5c4 rng=(15, 19) uci: d2d4 d7d5 c2c4 >>d5c4<< e2e4 ph: to tok: c4 mvidx: 3->d5c4 rng=(15, 19) uci: d2d4 d7d5 c2c4 >>d5c4<< e2e4 ph: from tok: e2 mvidx: 4->e2e4 rng=(20, 24) uci: d2d4 d7d5 c2c4 d5c4 >>e2e4<< e7e5
# Boolean mask [game, pos, vocab]: True where the token is a legal move.
is_valid_move = torch.zeros_like(big_logits, dtype=torch.bool, device="cpu")
for i in range(NUM_SEQS):
    for j in range(MOVE_LIMIT):
        valid_moves = valid_moves_list[i][j]
        valid_moves = [model.to_single_token(m) for m in valid_moves]
        is_valid_move[i, j, valid_moves] = True
is_valid_move = is_valid_move.cuda()
# Spot check: token ids of legal moves vs. the corresponding mask row.
print(model.to_tokens(valid_moves_list[3][1])[:, 1], is_valid_move[3, 1], sep="\n")
tensor([27, 28], device='cuda:0')
tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, True, True, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False], device='cuda:0')
def tensor_to_board(tensor):
    """Rearrange a logit-space tensor into board layout.

    Selects the 64 square entries from the trailing 77-entry vocab axis
    via ``logitId2square`` and returns an (..., 8, 8) board with the top
    rank first (flipped so the board prints the usual way up).

    Args:
        tensor: tensor of shape (..., 77) in logit-id order.

    Returns:
        Tensor of shape (..., 8, 8), same dtype/device as the input.
    """
    # Index the vocab axis explicitly so batched inputs work too; the
    # previous version indexed dim 0, which only handled 1-D tensors.
    board = tensor[..., logitId2square]
    # flip(-2) reverses the rank axis regardless of leading batch dims
    # (equivalent to the old flip(0) for the 1-D case).
    return board.reshape(board.shape[:-1] + (8, 8)).flip(-2)
# Render the legal-move mask for game 5, position 19 as a board heatmap,
# and print the token at that position for context.
imshow(
    tensor_to_board(is_valid_move[5, 19]),
    x=[chr(i) for i in range(ord("a"), ord("h") + 1)],
    y=list(range(1, 9)),
)
print(board_seqs_tokens[5][19])
h7
# for pos in range(len(tokens)-1):
# legal_tokens = valid_moves_list[pos]
# if len(legal_tokens) == 0:
# continue
# valid_move_indices = model.to_tokens(legal_tokens)[:, 1]
# is_valid_move[0, pos, valid_move_indices] = True
# for pos in random.sample(range(len(tokens)-1), 3):
# print(torch.where(is_valid_move[0, pos]), model.to_tokens(
# valid_moves_list[pos])[:, 1], sep='\t')
# Board view of the legal-move mask for a few positions, plus counts of
# legal moves overall.
for comp_idx in range(4, 7):
    imshow(is_valid_move[TEXT_INDEX, comp_idx, logitId2square].reshape(8, 8).flip(0))
    print("Num Valid moves: ", is_valid_move[TEXT_INDEX, comp_idx].sum(-1))
print("Max valid moves", is_valid_move.sum(-1).max())
print("Min valid moves", is_valid_move.sum(-1).min())
Num Valid moves: tensor(12, device='cuda:0')
Num Valid moves: tensor(2, device='cuda:0')
Num Valid moves: tensor(1, device='cuda:0') Max valid moves tensor(22, device='cuda:0') Min valid moves tensor(0, device='cuda:0')
# Full residual decomposition (per-layer attn + MLP outputs) for all
# games, LayerNorm-scaled and projected to logits.
stacked_big_resid, labels = big_cache.decompose_resid(-1, return_labels=True)
stacked_big_resid = big_cache.apply_ln_to_stack(stacked_big_resid, -1)
big_decomp_logits = stacked_big_resid @ model.W_U
print("big_decomp_logits.shape", big_decomp_logits.shape)
big_decomp_logits.shape torch.Size([26, 100, 127, 77])
def t(n, device="cuda"):
    """Wrap a scalar as a float32 tensor.

    Args:
        n: numeric value to wrap.
        device: target device. Defaults to "cuda" to match the rest of
            this notebook, but is parameterized so the helper also works
            in CPU-only environments.
    """
    return torch.tensor(n, device=device, dtype=torch.float32)
def sanity_check_tensor(tens):
    """Print a one-line summary of a tensor: its shape and value range."""
    lo = tens.min()
    hi = tens.max()
    print("sanity: ", tens.shape, lo, hi)
# Separation quality per component/position:
#   correct_logits   = worst (min) logit over legal tokens
#   incorrect_logits = best (max) logit over illegal tokens
# If min-correct exceeds max-incorrect, the component cleanly ranks every
# legal move above every illegal one. The +/-100 fill values keep masked
# entries from winning the min/max.
correct_logits = (
    torch.where(is_valid_move[None, :, :, :], big_decomp_logits, t(100)).min(-1).values
)[
    ..., :-1
]  # Need to index because final token blows up
incorrect_logits = (
    torch.where(~is_valid_move[None,], big_decomp_logits, t(-100)).max(-1).values
)[..., :-1]
sanity_check_tensor(correct_logits)
sanity_check_tensor(incorrect_logits)
sanity: torch.Size([26, 100, 126]) tensor(-5.8506, device='cuda:0') tensor(9.2044, device='cuda:0') sanity: torch.Size([26, 100, 126]) tensor(0.0050, device='cuda:0') tensor(6.7120, device='cuda:0')
# Heatmap of the separation margin per component/position, averaged
# over games.
scale = 2.0
margin = (correct_logits - incorrect_logits).mean(1).cpu()
imshow(
    margin,
    title="Min Correct - Max Incorrect",
    y=labels,
    xaxis="Position",
    zmin=-scale,
    zmax=scale,
)
Interestingly, 1_attn_out has nearly zero gradient across positions, whereas 0_attn_out gets worse over time
One thing to keep in mind here is the fact that there are far more
incorrect tokens than there are correct tokens. Specifically, there are an average of 69.9 invalid tiles for every token, but 7.08 valid tiles. Actually, it's a good thing that the model is allocating lots of negative value toward the invalid moves, because that would signal the model is suppressing invalid tiles.
Another interesting observation is the lattice pattern of the negative results. Let's see what happens if I remove every other component and every other position.
# Split the margin heatmap by component parity (MLPs sit at odd rows of
# the decomposition, attention at even rows) and by position parity
# (dest tokens at odd positions, source tokens at even ones). Panels are
# drawn in the same order as before: MLP dest, MLP source, attn dest,
# attn source.
for comp_par, pos_par, plot_title in [
    (1, 1, "Min Correct - Max Incorrect (DEST tile)"),
    (1, 0, "Min Correct - Max Incorrect (SOURCE tile)"),
    # For the sake of it, here's the attention heads now, too
    (0, 1, "Min Correct - Max Incorrect (DEST tile)"),
    (0, 0, "Min Correct - Max Incorrect (SOURCE tile)"),
]:
    margin = (correct_logits - incorrect_logits)[comp_par::2, :, pos_par::2]
    imshow(
        margin.mean(1).cpu(),
        title=plot_title,
        y=labels[comp_par::2],
        xaxis="Position",
        zmax=scale,
        zmin=-scale,
    )
1_attn_out is very interesting because it's essentially doing nothing?
Now I turn toward mean correct vs mean incorrect.
# Mean-based version of the same analysis: average logit over legal vs.
# illegal tokens (77 is the vocab size, so 77 - num_valid is the illegal
# count per position).
num_valid_moves = is_valid_move.sum(-1)
correct_logits_ave = (
    torch.where(is_valid_move[None, :, :, :], big_decomp_logits, t(0)).sum(-1)
    / num_valid_moves
)[
    ..., :-1
]  # Need to index because final token blows up
incorrect_logits_ave = (
    torch.where(~is_valid_move[None,], big_decomp_logits, t(0)).sum(-1)
    / (77 - num_valid_moves)
)[..., :-1]
sanity_check_tensor(correct_logits_ave)
sanity_check_tensor(incorrect_logits_ave)
scale = 3.0
# Seven panels drawn in the original order: the overall mean gap, then
# the gap split by component parity (odd rows = MLPs, even = attention)
# and position parity (odd = dest tokens, even = source tokens), and
# finally the raw correct / incorrect means.
mean_gap = correct_logits_ave - incorrect_logits_ave
panels = [
    (mean_gap, slice(None), "Mean Correct - Mean Incorrect"),
    (mean_gap[1::2, :, 1::2], slice(1, None, 2), "MLP: Mean Correct - Mean Incorrect (DEST tile)"),
    (mean_gap[0::2, :, 1::2], slice(0, None, 2), "Attention: Mean Correct - Mean Incorrect (DEST tile)"),
    (mean_gap[1::2, :, 0::2], slice(1, None, 2), "MLP: Mean Correct - Mean Incorrect (SOURCE tile)"),
    (mean_gap[0::2, :, 0::2], slice(0, None, 2), "Attention: Mean Correct - Mean Incorrect (SOURCE tile)"),
    (correct_logits_ave, slice(None), "Mean Correct"),
    (incorrect_logits_ave, slice(None), "Mean Incorrect"),
]
for panel, lab, plot_title in panels:
    imshow(
        panel.mean(-2).cpu(),
        title=plot_title,
        y=labels[lab],
        xaxis="Position",
        zmax=scale,
        zmin=-scale,
    )
sanity: torch.Size([26, 100, 126]) tensor(-4.9369, device='cuda:0') tensor(9.2044, device='cuda:0') sanity: torch.Size([26, 100, 126]) tensor(-0.5004, device='cuda:0') tensor(0.0935, device='cuda:0')
From the above graphs, I see a few things:
- MLPs are doing a lot of the work
- Attention layers may be attending to multiple tiles that are relevant, but not necessarily valid
- All layers learn there's a difference between src and dest tiles
- Nearly all Attn layers differentiate white/black turns of source (as evidenced by the zipper-like pattern). This information appears, albeit less prominently, in the MLPs and Dest tiles. I think this makes some sense, because if selecting pieces at random on a chess board, white tends toward one side and black tends toward the other. So, selection should be biased according to the player whose turn it is. However, once a piece is selected, no matter whose turn it is, that piece must obey the rules of its movement.
# Zooming into the heads themselves now.
# Per-head residual decomposition for the small (5-game) batch; neurons
# are kept collapsed per-MLP (expand_neurons=False).
stacked_small_resid, labels = small_cache.get_full_resid_decomposition(
    -1, return_labels=True, expand_neurons=False
)
stacked_small_resid = small_cache.apply_ln_to_stack(stacked_small_resid, -1)
small_decomp_logits = stacked_small_resid @ model.W_U
print(small_decomp_logits.shape)
Tried to stack head results when they weren't cached. Computing head results now torch.Size([159, 5, 127, 77])
# Same mean correct/incorrect computation, restricted to the games in
# the small batch (len(small_logits) == 5).
num_valid_moves = is_valid_move.sum(-1)
correct_logits_ave = (
    torch.where(
        is_valid_move[None, : len(small_logits), :, :], small_decomp_logits, t(0)
    ).sum(-1)
    / num_valid_moves[: len(small_logits)]
)[
    ..., :-1
]  # Need to index because final token blows up
incorrect_logits_ave = (
    torch.where(
        ~is_valid_move[None, : len(small_logits), :, :], small_decomp_logits, t(-0)
    ).sum(-1)
    / (77 - num_valid_moves)[: len(small_logits)]
)[..., :-1]
sanity_check_tensor(correct_logits_ave)
sanity_check_tensor(incorrect_logits_ave)
# Keep only decomposition rows whose label names an attention head
# (labels like "L3H7" contain an "H"; layer-level rows do not).
attn_head_indices = torch.tensor(
    [i for i, l in enumerate(labels) if "H" in l], device="cuda"
)
scale = 1
imshow(
    torch.index_select(
        (correct_logits_ave - incorrect_logits_ave), 0, attn_head_indices
    )
    .mean(-2)
    .cpu(),
    title="Mean Correct - Mean Incorrect",
    y=[labels[i] for i in attn_head_indices],
    xaxis="Position",
    zmax=scale,
    zmin=-scale,
)
sanity: torch.Size([159, 5, 126]) tensor(-2.8952, device='cuda:0') tensor(6.0168, device='cuda:0') sanity: torch.Size([159, 5, 126]) tensor(-0.4552, device='cuda:0') tensor(0.0615, device='cuda:0')
Activation patching¶
What I want to do here is to change the value of a single token, and then inspect how that changes the model's beliefs about valid follow-up moves
I believe the simplest way to find good candidates for corruption is to focus on pawn captures.
from transformer_lens import patching
In game 0, at token 6: Black's pawn on d5 is blocked by White's pawn on d4. White moves the c2 pawn to c4, allowing black to take on the next move. If we instead move c2c3, then the move d5c4 is invalidated. In game 0, pos 28, move queen to d5 instead of f6
# Activation patching setup: pick the token position to corrupt and the
# tile whose logit we track.
GAME_INDEX = 0
plot_valid_moves(board_seqs_string[GAME_INDEX], is_valid_move[GAME_INDEX], tokenizer)
CUTOFF = 28  # 7
"""This is the index for the token we're going to corrupt"""
LOGIT_ID = model.to_single_token("e4")  # d5 prev, can be e4 or d4
"""This is the logit_id for the tile whose validity changes post-
intervention."""
'This is the logit_id for the tile whose validity changes post-\nintervention.'
# Build clean and corrupted prefixes of the game up to CUTOFF tokens.
# Each slice is an independent list copy, so mutating corr_tokens below
# leaves clean_tokens untouched; the two differ only in the final token.
clean_tokens = board_seqs_tokens[GAME_INDEX][:CUTOFF]
corr_tokens = board_seqs_tokens[GAME_INDEX][:CUTOFF]
corr_tokens[-1] = "d5"  # was c3
print(clean_tokens, corr_tokens, sep="\n")
corr_labels = [i[1] for i in model.to_str_tokens(corr_tokens)]
clean_labels = [i[1] for i in model.to_str_tokens(clean_tokens)]
# Run both prefixes, caching activations; keep only the final-position
# logits for each.
clean_logits, clean_cache = model.run_with_cache(
    model.to_tokens(clean_tokens)[:, 1].unsqueeze(0)
)
clean_logits = clean_logits[0, -1]
corr_logits, corr_cache = model.run_with_cache(
    model.to_tokens(corr_tokens)[:, 1].unsqueeze(0)
)
corr_logits = corr_logits[0, -1]
print(clean_logits.shape, corr_logits.shape)
['<s>', 'd2', 'd4', 'd7', 'd5', 'c2', 'c4', 'd5', 'c4', 'e2', 'e4', 'e7', 'e5', 'g1', 'f3', 'f8', 'b4', 'c1', 'd2', 'b4', 'd2', 'b1', 'd2', 'b8', 'c6', 'f1', 'c4', 'd8'] ['<s>', 'd2', 'd4', 'd7', 'd5', 'c2', 'c4', 'd5', 'c4', 'e2', 'e4', 'e7', 'e5', 'g1', 'f3', 'f8', 'b4', 'c1', 'd2', 'b4', 'd2', 'b1', 'd2', 'b8', 'c6', 'f1', 'c4', 'd5'] torch.Size([77]) torch.Size([77])
# Board-level log probs for the clean vs. corrupted final position.
plot_single_board(clean_logits.log_softmax(-1), title="clean", zmin=-6.0, zmax=0)
plot_single_board(corr_logits.log_softmax(-1), title="corr", zmin=-6.0, zmax=0)
def logit_metric(logits):
    """Return the logit for the target tile (module-level LOGIT_ID).

    Accepts logits shaped [batch, pos, vocab], [pos, vocab], or [vocab];
    higher-rank inputs are reduced to the first batch element's final
    position before indexing.
    """
    if logits.ndim == 3:
        logits = logits[0]  # drop batch dim
    if logits.ndim == 2:
        logits = logits[-1]  # final position only
    return logits[LOGIT_ID]
# Baselines for normalizing the patching metric: the target-tile logit on
# the clean and corrupted runs.
clean_baseline = logit_metric(clean_logits).detach()
corr_baseline = logit_metric(corr_logits).detach()
print("Clean Baseline:", clean_baseline.item())
print("corr Baseline:", corr_baseline.item())


def metric(logits):
    """Normalized patching metric: 1.0 at the clean baseline, 0.0 at the
    corrupted baseline. (Was a lambda assignment; PEP 8 prefers def.)"""
    return (logit_metric(logits) - corr_baseline) / (
        clean_baseline - corr_baseline
    )
Clean Baseline: 1.162939429283142 corr Baseline: 14.47728443145752
# Patch each (layer, position) activation from the clean cache into the
# corrupted run, for residual stream, attention output, and MLP output.
# `model.to_tokens(...)` already returns a tensor, so the redundant
# torch.tensor(...) wrapper (which raised a copy-construct UserWarning)
# is removed.
every_block_act_patch_result = patching.get_act_patch_block_every(
    model,
    model.to_tokens(corr_tokens)[:, 1].unsqueeze(0).cuda(),
    clean_cache,
    metric,
)
<ipython-input-43-01e37a703b37>:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
0%| | 0/336 [00:00<?, ?it/s]
0%| | 0/336 [00:00<?, ?it/s]
0%| | 0/336 [00:00<?, ?it/s]
# Heatmaps of the patching metric per (layer, position), faceted by
# which activation type was patched.
imshow(
    every_block_act_patch_result,
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
    title="Activation Patching Per Block",
    xaxis="Position",
    yaxis="Layer",
    zmax=1,
    zmin=-1,
    x=[f"{tok}_{i}" for i, tok in enumerate(clean_labels)],
)
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]
# Per-head difference in the final destination row of the attention
# pattern between the clean and corrupted runs, flattened to
# [layer*head, src] (144 = n_layers * n_heads here).
imshow(
    (
        clean_cache.stack_activation("pattern")[:, 0, :, -1]
        - corr_cache.stack_activation("pattern")[:, 0, :, -1]
    ).reshape(144, -1),
    color_continuous_midpoint=0.0,
    y=head_labels,
    x=clean_labels,
    title="Activation Patching By Head",
)
DLA of diffed neurons¶
Here I'm applying DLA to the difference in neuron activations within the MLP layers between the clean and corrupted sequences. This reveals the neurons which are most affected by the change.
# Unembedding direction for the target tile.
# NOTE(review): the `1 +` offset is inconsistent with logit_metric, which
# indexes logits[LOGIT_ID] directly — confirm whether W_U's columns are
# really shifted by one relative to logit ids.
logit_vec = model.W_U[:, 1 + LOGIT_ID]
line_list = []
for mlp_index in range(model.cfg.n_layers):
    # Neuron-wise activation difference at the final position...
    mlp_diff = (
        clean_cache["post", mlp_index, "mlp"][0, -1]
        - corr_cache["post", mlp_index, "mlp"][0, -1]
    )
    # ...weighted by each neuron's direct effect on the target logit.
    line_list.append(mlp_diff * (model.blocks[mlp_index].mlp.W_out @ logit_vec))
line(line_list, title="Direct Logit Attribution of Diffed Neurons")
It's interesting to differentiate this behavior from that observed in the Othello paper. In the Othello paper, most neurons had very small diffs (<0.2). Here, the differences are ~3x larger across the spectrum (~0.75). Still, several layers exhibited similar behavior to that observed in the Othello paper: 11, 10.
Most interestingly, layer 10 had ~4 neurons which really mattered a lot.
Tangent on Analyzing Neurons¶
Looking at mean activations across 100 games. Neurons are sorted in order of increasing mean activation level. So position doesn't matter, just the relative differences between the layers. Observe just as in othello, this chess GPT's two final MLP layers have the biggest outliers.
# Sorted mean activations of every MLP neuron, per layer — once over the
# center positions only (trimming 15 tokens from each end) and once over
# all positions.
line(
    [
        big_cache["post", i, "mlp"][:, 15:-15].mean([0, 1]).sort().values
        for i in range(12)
    ],
    title="Means of MLP Neurons (Center Pos)",
)
line(
    [big_cache["post", i, "mlp"][:, :].mean([0, 1]).sort().values for i in range(12)],
    title="Means of MLP Neurons (All Pos)",
)
DLA on top neurons¶
since there are some neurons which are much stronger than others, I want to look at those in more detail. Since it's clear many neurons have very low mean activation (plots above), I set a cutoff at 0.2 and look at those that remain.
# For each layer, show the direct logit attribution (W_out @ W_U) of the
# neurons whose mean activation across all games exceeds the cutoff.
for mlp_index in range(0, 12):
    cutoff = 0.2  # chosen completely arbitrarily
    neuron_means = big_cache["post", mlp_index, "mlp"][:, :].mean([0, 1])
    neuron_indices = neuron_means > cutoff
    print(neuron_indices)
    W_out = model.blocks[mlp_index].mlp.W_out[neuron_indices]
    print(f"{neuron_indices.sum().item()} Neurons above cutoff {cutoff}")
    print((W_out @ model.W_U).shape)
    imshow(
        (W_out @ model.W_U),  # .softmax(-1),
        y=list(map(str, torch.arange(3072)[neuron_indices.cpu()].tolist())),
        yaxis="Neuron",
        title=f"Direct Logit Attr of top neurons in Layer {mlp_index}",
        xaxis="tile",
        # color_continuous_scale="viridis",
        x=[f"{logitId2token[i]} ({i})" for i in range(77)],
    )
tensor([False, False, False, ..., False, False, False], device='cuda:0') 13 Neurons above cutoff 0.2 torch.Size([13, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 27 Neurons above cutoff 0.2 torch.Size([27, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 47 Neurons above cutoff 0.2 torch.Size([47, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 23 Neurons above cutoff 0.2 torch.Size([23, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 25 Neurons above cutoff 0.2 torch.Size([25, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 29 Neurons above cutoff 0.2 torch.Size([29, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 25 Neurons above cutoff 0.2 torch.Size([25, 77])
tensor([False, True, False, ..., False, False, False], device='cuda:0') 36 Neurons above cutoff 0.2 torch.Size([36, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 18 Neurons above cutoff 0.2 torch.Size([18, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 28 Neurons above cutoff 0.2 torch.Size([28, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 78 Neurons above cutoff 0.2 torch.Size([78, 77])
tensor([False, False, False, ..., False, False, False], device='cuda:0') 183 Neurons above cutoff 0.2 torch.Size([183, 77])
# Full-game board log probs from the batched ("big") run.
plot_board_log_probs(board_seqs_string[GAME_INDEX], tokenizer, big_logits[GAME_INDEX])
# DLA of top neurons split by move phase: positions cycle with period 4
# through (white start, white end, black start, black end), so slicing
# positions with i::4 isolates one phase.
for mlp_index in range(5, 8):
    for i in [0, 2, 1, 3]:
        title_text = ["(White Start)", "(White End)", "(Black Start)", "(Black End)"]
        cutoff = 0.3  # chosen completely arbitrarily
        neuron_means = big_cache["post", mlp_index, "mlp"][:, slice(i, 127, 4)].mean([0, 1])
        # BUG FIX: this previously thresholded `neuron_means_diff`, a stale
        # variable from another cell, so every plot selected the same 17
        # neurons regardless of layer/phase (see the constant printouts).
        # Threshold this cell's freshly computed means instead.
        neuron_indices = abs(neuron_means) > cutoff
        W_out = model.blocks[mlp_index].mlp.W_out[neuron_indices]
        print(f"{neuron_indices.sum().item()} Neurons above cutoff {cutoff}")
        print((W_out @ model.W_U).shape)
        imshow(
            (W_out @ model.W_U),  # .softmax(-1),
            y=list(map(str, torch.arange(3072)[neuron_indices.cpu()].tolist())),
            yaxis="Neuron",
            title=f"DLA: Mean output of top neurons in Layer {mlp_index} {title_text[i]}",
            xaxis="tile",
            # color_continuous_scale="viridis",
            x=[f"{logitId2token[i]} ({i})" for i in range(77)],
        )
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
17 Neurons above cutoff 0.3 torch.Size([17, 77])
# Quick shape check of one phase slice (every 4th position).
big_cache["post", mlp_index, "mlp"][:, slice(3,127,4)].shape
torch.Size([100, 31, 3072])
# Pairwise differences between phase-sliced mean activations in layer 5.
# Hoisted: the original recomputed each of the four phase means inside
# the double loop (16 redundant passes over the cache).
phase_means = [
    big_cache["post", 5, "mlp"][:, slice(k, 127, 4)].mean([0, 1]) for k in range(4)
]
for i in range(4):
    for j in range(4):
        diff = phase_means[i] - phase_means[j]
        print(i, j,
              float(diff.mean()), float(diff.max()), float(diff.min()),
              sep='||'
              )
0||0||0.0||0.0||0.0 0||1||0.0022479859180748463||0.6269872784614563||-0.612943708896637 0||2||0.0007001814665272832||0.401024729013443||-0.5666456818580627 0||3||0.0024760710075497627||0.6269792914390564||-0.7314262390136719 1||0||-0.0022479859180748463||0.612943708896637||-0.6269872784614563 1||1||0.0||0.0||0.0 1||2||-0.0015478043351322412||0.5557777881622314||-0.7357940077781677 1||3||0.00022808529320172966||0.4563901424407959||-0.7119348049163818 2||0||-0.0007001814665272832||0.5666456818580627||-0.401024729013443 2||1||0.0015478043351322412||0.7357940077781677||-0.5557777881622314 2||2||0.0||0.0||0.0 2||3||0.0017758896574378014||0.7774345874786377||-0.5624321699142456 3||0||-0.0024760710075497627||0.7314262390136719||-0.6269792914390564 3||1||-0.00022808529320172966||0.7119348049163818||-0.4563901424407959 3||2||-0.0017758896574378014||0.5624321699142456||-0.7774345874786377 3||3||0.0||0.0||0.0